{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.9906159 ,  0.13667548,  0.8457832 ], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Environment wrapper: exposes the classic 4-tuple step() API and caps\n",
    "# each episode at 200 steps.\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0  # number of steps taken in the current episode\n",
    "\n",
    "    def reset(self):\n",
    "        # Drop the info dict so callers get just the observation.\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        # Collapse the gymnasium-style 5-tuple into the old 4-tuple convention.\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        # Hard time limit: force the episode to end after 200 steps.\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAirUlEQVR4nO3df3DU9YH/8dduNruQH7shQHbNkQhXmGKGHyogbL253pWUtJdra6UznuW8nGV05AIH0nEqPcVppzNh9HvX1jvFu7k59Q+VGzrFVk5qcwHjtYZfEWxATbWHJBU2QWh2k0A2ye77+4dlz8VoE/LJ7nuT52NmZ+Tzee877/0I+8xmP/msyxhjBACAhdzZXgAAAB+HSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArJW1SD322GOaO3eupk2bppUrV+rw4cPZWgoAwFJZidR//ud/auvWrXrooYf02muvaenSpaqpqVF3d3c2lgMAsJQrGxeYXblypVasWKF/+Zd/kSQlk0lVVFRo06ZNuv/++zO9HACApTyZ/oKDg4NqbW3Vtm3bUtvcbreqq6vV0tIy4n3i8bji8Xjqz8lkUhcuXNDMmTPlcrkmfM0AAGcZY9Tb26vy8nK53R//Q72MR+r9999XIpFQMBhM2x4MBvXWW2+NeJ+GhgZ95zvfycTyAAAZ1NnZqTlz5nzs/oxH6mps27ZNW7duTf05Go2qsrJSnZ2d8vv9WVwZAOBqxGIxVVRUqLi4+BPHZTxSs2bNUl5enrq6utK2d3V1KRQKjXgfn88nn8/3ke1+v59IAUAO+0Nv2WT87D6v16tly5apqakptS2ZTKqpqUnhcDjTywEAWCwrP+7bunWr6urqtHz5ct100036wQ9+oP7+ft15553ZWA4AwFJZidRtt92mc+fOafv27YpEIrr++uv1s5/97CMnUwAApras/J7UeMViMQUCAUWjUd6TAoAcNNrnca7dBwCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaY47UK6+8oi996UsqLy+Xy+XS888/n7bfGKPt27frmmuu0fTp01VdXa233347bcyFCxe0bt06+f1+lZSUaP369err6xvXAwEATD5jjlR/f7+WLl2qxx57bMT9Dz/8sB599FE98cQTOnTokAoLC1VTU6OBgYHUmHXr1unkyZNqbGzU3r179corr+juu++++kcBAJiczDhIMnv27En9OZlMmlAoZB555JHUtp6eHuPz+cxzzz1njDHmjTfeMJLMkSNHUmP27dtnXC6Xee+990b1daPRqJFkotHoeJYPAMiS0T6PO/qe1KlTpxSJRFRdXZ3aFggEtHLlSrW0tEiSWlpaVFJSouXLl6fGVFdXy+1269ChQyPOG4/HFYvF0m4AgMnP0UhFIhFJUjAYTNseDAZT+yKRiMrKytL2ezwelZaWpsZcqaGhQYFAIHWrqKhwctkAAEvlxNl927ZtUzQaTd06OzuzvSQAQAY4GqlQKCRJ6urqStve1dWV2hcKhdTd3Z22f3h4WBcuXEiNuZLP55Pf70+7AQAmP0cjNW/ePIVCITU1NaW2xWIxHTp0
SOFwWJIUDofV09Oj1tbW1Jj9+/crmUxq5cqVTi4HAJDjPGO9Q19fn955553Un0+dOqXjx4+rtLRUlZWV2rJli773ve9pwYIFmjdvnh588EGVl5frlltukSRdd911+sIXvqC77rpLTzzxhIaGhrRx40b91V/9lcrLyx17YACASWCspw0eOHDASPrIra6uzhjzwWnoDz74oAkGg8bn85nVq1eb9vb2tDnOnz9vbr/9dlNUVGT8fr+58847TW9vr+OnLgIA7DTa53GXMcZksZFXJRaLKRAIKBqN8v4UAOSg0T6P58TZfQCAqYlIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrTFFqqGhQStWrFBxcbHKysp0yy23qL29PW3MwMCA6uvrNXPmTBUVFWnt2rXq6upKG9PR0aHa2loVFBSorKxM9913n4aHh8f/aAAAk8qYItXc3Kz6+nodPHhQjY2NGhoa0po1a9Tf358ac++99+qFF17Q7t271dzcrDNnzujWW29N7U8kEqqtrdXg4KBeffVVPf3003rqqae0fft25x4VAGByMOPQ3d1tJJnm5mZjjDE9PT0mPz/f7N69OzXmzTffNJJMS0uLMcaYF1980bjdbhOJRFJjdu7cafx+v4nH46P6utFo1Egy0Wh0PMsHAGTJaJ/Hx/WeVDQalSSVlpZKklpbWzU0NKTq6urUmIULF6qyslItLS2SpJaWFi1evFjBYDA1pqamRrFYTCdPnhzx68TjccVisbQbAGDyu+pIJZNJbdmyRTfffLMWLVokSYpEIvJ6vSopKUkbGwwGFYlEUmM+HKjL+y/vG0lDQ4MCgUDqVlFRcbXLBgDkkKuOVH19vU6cOKFdu3Y5uZ4Rbdu2TdFoNHXr7Oyc8K8JAMg+z9XcaePGjdq7d69eeeUVzZkzJ7U9FAppcHBQPT09aa+murq6FAqFUmMOHz6cNt/ls/8uj7mSz+eTz+e7mqUCAHLYmF5JGWO0ceNG7dmzR/v379e8efPS9i9btkz5+flqampKbWtvb1dHR4fC4bAkKRwOq62tTd3d3akxjY2N8vv9qqqqGs9jAQBMMmN6JVVfX69nn31WP/nJT1RcXJx6DykQCGj69OkKBAJav369tm7dqtLSUvn9fm3atEnhcFirVq2SJK1Zs0ZVVVW644479PDDDysSieiBBx5QfX09r5YAAGlcxhgz6sEu14jbn3zySf3t3/6tpA9+mfeb3/ymnnvuOcXjcdXU1Ojxxx9P+1He6dOntWHDBr388ssqLCxUXV2dduzYIY9ndM2MxWIKBAKKRqPy+/2jXT4AwBKjfR4fU6RsQaQAILeN9nmca/cBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrebK9AAD/xxjzsftcLlcGVwLYgUgBFjCJhIZ7exV77TX1HDmigc5OJS5dksfvV+H8+ZrxJ3+igk99SnmFhcQKUwqRArIsGY+r5+BBdb3wgi6+/bb0oVdTQ+fO6dJvfqPzBw4ocOONKrvlFhVddx2hwpRBpIAsMsbo3M9/rsju
3Rru6fn4cYOD6jl4UANnz6ry7rtVtGgRocKUwIkTQJaYRELn//u/deaZZz4xUB82cPq0Ov7t39T35puf+P4VMFkQKSBL+n/9a0V271by4sW07e/192tvZ6ee+9//1X+fOaP+oaG0/QOnT+vsc88p0deXyeUCWcGP+4AsSA4NKXr0qOKRSGqbMUan+vr00LFjerevTwOJhPz5+Vo0Y4b+34oVynf/3/eUva+/rou/+Y2Kly7lx36Y1HglBWTB0Pnz6vrxj9O2/W9fn+765S/1ZjSqS4mEjKTo0JB+2d2tzYcO6fzAQNr4jieeyOCKgewgUkCGGWN0vqlJJpFI2/6DkycVveJHe5cdfv99NZ45kz7Px4wFJhMiBWRYor9f5156KdvLAHICkQIybLC7+yOvogCMjEgBGRZ7/XUlr3h/SZJqKyqU/zEnQcwtKtKS0tKJXhpgHc7uAzLs0rvvjvh+Uk15uSTpe6+/rsFEQklJeS6XSrxe/eOKFbq2qCjDKwWyj0gBlnC5XKopL9ecggLt/e1vdX5gQHOLinTbvHma6fN9ZLw3GMzCKoHMIlJABplEQkomP3a/y+XSohkztGjGjD8418w//3MnlwZYifekgAxKXLyoxBVXmLhavt//eBCYzIgUkEHDfX0aduhyRt5ZsxyZB7AZkQIyKP7eexr47W+dmczl4pJImPSIFJBBw7GYEr29457H7fMRKEwJRArIQSU336w8TknHFECkgAwxyaQSI/wS79Xwzpoll4eTczH5ESkgQ0wiocFz5xyZyztzJpHClECkgAxJxuPqO3HCkblceXm8J4UpgUgBGZKMx9Xf3u7MZAQKUwSRAnLM9HnzVLBgQbaXAWQEkQIyxKmP5/AUFys/EHBkLsB2RArIkKHf/c6RefIKCzn9HFMGkQIyJH72rCPzuDwezuzDlEGkgAyJHjni2Fyc2YepgkgBGXLxnXfGPYfL65X/xhsdWA2QG4gUkEPc+fkq+OM/zvYygIwhUkAGJONxmU/4sMNRc7vlnT17/PMAOYJIARkweOGCzPDwuOdxud3KKyhwYEVAbiBSQAZcfOcdJS5dcmYyTprAFEKkgAzobWtT0oGPjS+qqnJgNUDuIFJADim+/vpsLwHIKCIFTDBjjGSMI3NN+6M/cmQeIFcQKWCCJeNxJR36sEOu2YephkgBEyxx8aKGe3udmczt5moTmFKIFDDBBs+d08CZM+OeJ3/2bLm9XgdWBOQOIgVMsKHz5zUYiYx7nqKqKq5+jimHSAE5wjtrltz5+dleBpBRRAqYQMYYJQcHHZkrPxCQi0hhiiFSwERKJjV04YIjU7ny8jhpAlPOmCK1c+dOLVmyRH6/X36/X+FwWPv27UvtHxgYUH19vWbOnKmioiKtXbtWXV1daXN0dHSotrZWBQUFKisr03333adhB65pBtjIDA878xEd+flyT5/uwIqA3DKmSM2ZM0c7duxQa2urjh49qs997nP6yle+opMnT0qS7r33Xr3wwgvavXu3mpubdebMGd16662p+ycSCdXW1mpwcFCvvvqqnn76aT311FPavn27s48KsERycFC/++Uvxz2Pr7xcBfPnO7AiILe4jBnfr8KXlpbqkUce0de+9jXNnj1bzz77rL72ta9Jkt566y1dd911amlp0apVq7Rv3z795V/+pc6cOaNgMChJeuKJJ/Stb31L586dk3eUp9fGYjEFAgFFo1H5/f7xLB+YUMO9vXr9r/963FecKLzuOv3xt74lb2mpQysDsmu0z+NX/Z5UIpHQrl271N/fr3A4rNbWVg0NDam6ujo1ZuHChaqsrFRLS4skqaWlRYsXL04FSpJqamoUi8VSr8ZGEo/HFYvF0m5ALhjn94ApedOny8M3ZJiCxhyptrY2FRUVyefz6Z577tGePXtUVVWlSCQir9erkpKStPHBYFCR3/+OSCQSSQvU5f2X932choYGBQKB1K2iomKsywayIuHQlSZceXlyezyOzAXk
kjFH6tOf/rSOHz+uQ4cOacOGDaqrq9Mbb7wxEWtL2bZtm6LRaOrW2dk5oV8PcEo8EnHk4rKuvDwHVgPknjF/a+b1ejX/92/gLlu2TEeOHNEPf/hD3XbbbRocHFRPT0/aq6muri6FQiFJUigU0uHDh9Pmu3z23+UxI/H5fPL5fGNdKpB1fQ58A+fKz1fpn/3Z+BcD5KBx/55UMplUPB7XsmXLlJ+fr6amptS+9vZ2dXR0KBwOS5LC4bDa2trU3d2dGtPY2Ci/368qPswNk1D0im/KrobL7Za3rMyB1QC5Z0yvpLZt26YvfvGLqqysVG9vr5599lm9/PLLeumllxQIBLR+/Xpt3bpVpaWl8vv92rRpk8LhsFatWiVJWrNmjaqqqnTHHXfo4YcfViQS0QMPPKD6+npeKQEfx+2W74r3coGpYkyR6u7u1t/8zd/o7NmzCgQCWrJkiV566SV9/vOflyR9//vfl9vt1tq1axWPx1VTU6PHH388df+8vDzt3btXGzZsUDgcVmFhoerq6vTd737X2UcFWMAMDztzdp/LpbyCgvHPA+Sgcf+eVDbwe1LIBYMXLujX3/624uP8mA53QYGuf+YZTp7ApDLhvycF4JMNXbjgyMVleRWFqYxIARPk0unTSly8OO55grfeKrn5p4qpib/5wAS5dPq0kg5EahonTWAKI1LABDDGOPJLvJLkueIqLsBUQqSACWASCRmHPuxQLhefI4Upi0gBEyB56ZKGHbpuHzCVESlgAgzHYopf8YGfV6PkM5+Rj6tNYAojUsAEiHd16eLbb497Hu+sWXJPm+bAioDcRKQAi3mKi+XKz8/2MoCsIVKAw4wxUjLpzGRuNydNYEojUoDTjNHQ736X7VUAkwKRApxmjAY/9HE0V8tbVqYiPsIGUxyRAhxmEgn1O3DShMfv17TycgdWBOQuIgU4zCQS6n399XHP4/b5lFdc7MCKgNxFpABLudxuuT1j+sg3YNIhUoDDkvG4cu5D2gBLESnAYfHu7vFfXNbtVvGSJc4sCMhhRApw2EBHx7gj5XK75b/hBodWBOQuIgU47P2XXhr/KymXS95QyJkFATmMSAE2crmUxzX7ACIFOMnJDzsEQKQARyX6+5UcGhr3PNOvvVZy888T4F8B4KDhaFRJBz6Rd8bNN8tFpAAiBTgpHokocfHiuOfxXXONxNXPASIFOKn/rbc07MAV0D1+Px/RAYhIAY4xnDABOI5IAU4xRiaRGPc0eYWFcvNpvIAkIgU4JhmPazgWG/c8xUuXyhsMOrAiIPcRKcAhyYEBRz6RN7+khF/kBX6PSAEOGThzRtGjR8c9T15RkVxerwMrAnIfkQKc5MTJEy4XZ/YBv0ekAAc4djkkl0uuvLzxzwNMEkQKcMhwb++45/CWlalkxQoHVgNMDkQKcIIxGuzuHvc07mnT5Jkxw4EFAZMDkQKcYIwGfvvbcU/j9nqVX1Iy/vUAkwSRAhxgkkmdP3Bg/BO5XFxYFvgQ/jUAAKxFpABbuFya9fnPZ3sVgFWIFGALl0vT587N9ioAqxApwBYul3xcsw9IQ6QAh+QHAuO7v98vN9fsA9IQKcABLo9Hf1RXN645yu+4Q26u2QekIVKAA1wul4qqqlR8/fVXdf+iRYtUvHgxp58DV+BfBOCQ/JkzVf71r2v6vHljut+0igqVf/3r8paVTdDKgNxFpACHuFwuFS1cqDnr12vatdeO6j7T5szRnG98Q8WLFnHlc2AERApwWPGiRZq7aZNm/OmffnAixJXxcbnk8no14+abde3f/738N9yQnYUCOcCT7QUAk43L7VbBggWau2WLLr37rnpaWnSpo0OJvj7lFRZqemWlAitXquBTn5IrL49XUMAnIFLABHC5XHJ5PCqcP1+F8+dnezlAzuLHfQAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItI
AQCsNa5I7dixQy6XS1u2bEltGxgYUH19vWbOnKmioiKtXbtWXV1daffr6OhQbW2tCgoKVFZWpvvuu0/Dw8PjWQoAYBK66kgdOXJE//qv/6olS5akbb/33nv1wgsvaPfu3WpubtaZM2d06623pvYnEgnV1tZqcHBQr776qp5++mk99dRT2r59+9U/CgDA5GSuQm9vr1mwYIFpbGw0n/3sZ83mzZuNMcb09PSY/Px8s3v37tTYN99800gyLS0txhhjXnzxReN2u00kEkmN2blzp/H7/SYej4/q60ejUSPJRKPRq1k+ACDLRvs8flWvpOrr61VbW6vq6uq07a2trRoaGkrbvnDhQlVWVqqlpUWS1NLSosWLFysYDKbG1NTUKBaL6eTJkyN+vXg8rlgslnYDAEx+nrHeYdeuXXrttdd05MiRj+yLRCLyer0qKSlJ2x4MBhWJRFJjPhyoy/sv7xtJQ0ODvvOd74x1qQCAHDemV1KdnZ3avHmznnnmGU2bNm2i1vQR27ZtUzQaTd06Ozsz9rUBANkzpki1traqu7tbN954ozwejzwej5qbm/Xoo4/K4/EoGAxqcHBQPT09affr6upSKBSSJIVCoY+c7Xf5z5fHXMnn88nv96fdAACT35gitXr1arW1ten48eOp2/Lly7Vu3brUf+fn56upqSl1n/b2dnV0dCgcDkuSwuGw2tra1N3dnRrT2Ngov9+vqqoqhx4WAGAyGNN7UsXFxVq0aFHatsLCQs2cOTO1ff369dq6datKS0vl9/u1adMmhcNhrVq1SpK0Zs0aVVVV6Y477tDDDz+sSCSiBx54QPX19fL5fA49LADAZDDmEyf+kO9///tyu91au3at4vG4ampq9Pjjj6f25+Xlae/evdqwYYPC4bAKCwtVV1en7373u04vBQCQ41zGGJPtRYxVLBZTIBBQNBrl/SkAyEGjfR7n2n0AAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGt5sr2Aq2GMkSTFYrEsrwQAcDUuP39ffj7/ODkZqfPnz0uSKioqsrwSAMB49Pb2KhAIfOz+nIxUaWmpJKmjo+MTH9xUF4vFVFFRoc7OTvn9/mwvx1ocp9HhOI0Ox2l0jDHq7e1VeXn5J47LyUi53R+8lRYIBPhLMAp+v5/jNAocp9HhOI0Ox+kPG82LDE6cAABYi0gBAKyVk5Hy+Xx66KGH5PP5sr0Uq3GcRofjNDocp9HhODnLZf7Q+X8AAGRJTr6SAgBMDUQKAGAtIgUAsBaRAgBYKycj9dhjj2nu3LmaNm2aVq5cqcOHD2d7SRn1yiuv6Etf+pLKy8vlcrn0/PPPp+03xmj79u265pprNH36dFVXV+vtt99OG3PhwgWtW7dOfr9fJSUlWr9+vfr6+jL4KCZWQ0ODVqxYoeLiYpWVlemWW25Re3t72piBgQHV19dr5syZKioq0tq1a9XV1ZU2pqOjQ7W1tSooKFBZWZnuu+8+DQ8PZ/KhTKidO3dqyZIlqV88DYfD2rdvX2o/x2hkO3bskMvl0pYtW1LbOFYTxOSYXbt2Ga/Xa/7jP/7DnDx50tx1112mpKTEdHV1ZXtpGfPiiy+af/iHfzA//vGPjSSz
Z8+etP07duwwgUDAPP/88+b11183X/7yl828efPMpUuXUmO+8IUvmKVLl5qDBw+a//mf/zHz5883t99+e4YfycSpqakxTz75pDlx4oQ5fvy4+Yu/+AtTWVlp+vr6UmPuueceU1FRYZqamszRo0fNqlWrzGc+85nU/uHhYbNo0SJTXV1tjh07Zl588UUza9Yss23btmw8pAnx05/+1PzXf/2X+fWvf23a29vNt7/9bZOfn29OnDhhjOEYjeTw4cNm7ty5ZsmSJWbz5s2p7RyriZFzkbrppptMfX196s+JRMKUl5ebhoaGLK4qe66MVDKZNKFQyDzyyCOpbT09Pcbn85nnnnvOGGPMG2+8YSSZI0eOpMbs27fPuFwu895772Vs7ZnU3d1tJJnm5mZjzAfHJD8/3+zevTs15s033zSSTEtLizHmg28G3G63iUQiqTE7d+40fr/fxOPxzD6ADJoxY4b593//d47RCHp7e82CBQtMY2Oj+exnP5uKFMdq4uTUj/sGBwfV2tqq6urq1Da3263q6mq1tLRkcWX2OHXqlCKRSNoxCgQCWrlyZeoYtbS0qKSkRMuXL0+Nqa6ultvt1qFDhzK+5kyIRqOS/u/ixK2trRoaGko7TgsXLlRlZWXacVq8eLGCwWBqTE1NjWKxmE6ePJnB1WdGIpHQrl271N/fr3A4zDEaQX19vWpra9OOicTfp4mUUxeYff/995VIJNL+J0tSMBjUW2+9laVV2SUSiUjSiMfo8r5IJKKysrK0/R6PR6Wlpakxk0kymdSWLVt08803a9GiRZI+OAZer1clJSVpY688TiMdx8v7Jou2tjaFw2ENDAyoqKhIe/bsUVVVlY4fP84x+pBdu3bptdde05EjRz6yj79PEyenIgVcjfr6ep04cUK/+MUvsr0UK33605/W8ePHFY1G9aMf/Uh1dXVqbm7O9rKs0tnZqc2bN6uxsVHTpk3L9nKmlJz6cd+sWbOUl5f3kTNmurq6FAqFsrQqu1w+Dp90jEKhkLq7u9P2Dw8P68KFC5PuOG7cuFF79+7VgQMHNGfOnNT2UCikwcFB9fT0pI2/8jiNdBwv75ssvF6v5s+fr2XLlqmhoUFLly7VD3/4Q47Rh7S2tqq7u1s33nijPB6PPB6Pmpub9eijj8rj8SgYDHKsJkhORcrr9WrZsmVqampKbUsmk2pqalI4HM7iyuwxb948hUKhtGMUi8V06NCh1DEKh8Pq6elRa2trasz+/fuVTCa1cuXKjK95IhhjtHHjRu3Zs0f79+/XvHnz0vYvW7ZM+fn5acepvb1dHR0dacepra0tLeiNjY3y+/2qqqrKzAPJgmQyqXg8zjH6kNWrV6utrU3Hjx9P3ZYvX65169al/ptjNUGyfebGWO3atcv4fD7z1FNPmTfeeMPcfffdpqSkJO2Mmcmut7fXHDt2zBw7dsxIMv/0T/9kjh07Zk6fPm2M+eAU9JKSEvOTn/zE/OpXvzJf+cpXRjwF/YYbbjCHDh0yv/jFL8yCBQsm1SnoGzZsMIFAwLz88svm7NmzqdvFixdTY+655x5TWVlp9u/fb44ePWrC4bAJh8Op/ZdPGV6zZo05fvy4+dnPfmZmz549qU4Zvv/++01zc7M5deqU+dWvfmXuv/9+43K5zM9//nNjDMfok3z47D5jOFYTJeciZYwx//zP/2wqKyuN1+s1N910kzl48GC2l5RRBw4cMJI+cqurqzPGfHAa+oMPPmiCwaDx+Xxm9erVpr29PW2O8+fPm9tvv90UFRUZv99v7rzzTtPb25uFRzMxRjo+ksyTTz6ZGnPp0iXzd3/3d2bGjBmmoKDAfPWrXzVnz55Nm+fdd981X/ziF8306dPNrFmzzDe/+U0zNDSU4Uczcb7xjW+Ya6+91ni9XjN79myzevXqVKCM4Rh9kisjxbGaGHxUBwDAWjn1nhQAYGohUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQA
ANYiUgAAaxEpAIC1iBQAwFr/H2UUg3pYDeFnAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Render the current frame of the environment inline with matplotlib.\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.2205],\n",
       "         [-1.1850]], grad_fn=<MulBackward0>),\n",
       " tensor([[0.4726],\n",
       "         [0.5359]], grad_fn=<NegBackward0>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class ModelAction(torch.nn.Module):\n",
    "    \"\"\"Squashed-Gaussian policy: maps state [b, 3] to (action in [-2, 2], entropy term [b, 1]).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc_state = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "        self.fc_mu = torch.nn.Linear(128, 1)\n",
    "        # Softplus keeps the predicted standard deviation strictly positive.\n",
    "        self.fc_std = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Softplus(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        #[b, 3] -> [b, 128]\n",
    "        state = self.fc_state(state)\n",
    "\n",
    "        #[b, 128] -> [b, 1]\n",
    "        mu = self.fc_mu(state)\n",
    "\n",
    "        #[b, 128] -> [b, 1]\n",
    "        std = self.fc_std(state)\n",
    "\n",
    "        # b independent Normal distributions parameterized by mu and std.\n",
    "        dist = torch.distributions.Normal(mu, std)\n",
    "\n",
    "        # rsample() is the reparameterization trick: draw from a standard\n",
    "        # normal, scale by std and shift by mu, so gradients can flow through.\n",
    "        sample = dist.rsample()\n",
    "\n",
    "        # Squash the raw sample into (-1, 1).\n",
    "        action = torch.tanh(sample)\n",
    "\n",
    "        # Log-probability of the raw (pre-squash) sample.\n",
    "        log_prob = dist.log_prob(sample)\n",
    "\n",
    "        # Tanh change-of-variables correction: log pi(a) = log p(u) - log(1 - tanh(u)^2).\n",
    "        # Bug fix: `action` is already tanh(sample), so the correction uses\n",
    "        # action**2; the original `action.tanh()**2` applied tanh twice and\n",
    "        # understated the correction for near-saturated actions.\n",
    "        entropy = log_prob - (1 - action**2 + 1e-7).log()\n",
    "        entropy = -entropy\n",
    "\n",
    "        # Rescale the action to Pendulum's range [-2, 2].\n",
    "        return action * 2, entropy\n",
    "\n",
    "\n",
    "model_action = ModelAction()\n",
    "\n",
    "model_action(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1757],\n",
       "        [0.0340]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class ModelValue(torch.nn.Module):\n",
    "    \"\"\"Q-value network: maps (state [b, 3], action [b, 1]) to a value estimate [b, 1].\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.sequential = torch.nn.Sequential(\n",
    "            torch.nn.Linear(4, 128),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(128, 128),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(128, 1),\n",
    "        )\n",
    "\n",
    "    def forward(self, state, action):\n",
    "        # Q(s, a): concatenate state and action along the feature axis.\n",
    "        #[b, 3+1] -> [b, 4]\n",
    "        state = torch.cat([state, action], dim=1)\n",
    "\n",
    "        #[b, 4] -> [b, 1]\n",
    "        return self.sequential(state)\n",
    "\n",
    "\n",
    "# Twin Q-networks (clipped double-Q) plus their target copies.\n",
    "model_value1 = ModelValue()\n",
    "model_value2 = ModelValue()\n",
    "\n",
    "model_value_next1 = ModelValue()\n",
    "model_value_next2 = ModelValue()\n",
    "\n",
    "# Start the target networks from the same weights as the online networks.\n",
    "model_value_next1.load_state_dict(model_value1.state_dict())\n",
    "model_value_next2.load_state_dict(model_value2.state_dict())\n",
    "\n",
    "model_value1(torch.randn(2, 3), torch.randn(2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.13071702420711517"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    # Sample one action from the policy for a single raw state (sequence of\n",
    "    # 3 floats) and return it as a plain Python float.\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    action, _ = model_action(state)\n",
    "    return action.item()\n",
    "\n",
    "\n",
    "get_action([1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(200,\n",
       " (array([-0.05687417, -0.9983814 ,  0.74806935], dtype=float32),\n",
       "  0.27284643054008484,\n",
       "  -2.705446445635012,\n",
       "  array([-0.0548668 , -0.9984937 ,  0.04021031], dtype=float32),\n",
       "  False))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Replay buffer: a list of (state, action, reward, next_state, over) tuples.\n",
    "datas = []\n",
    "\n",
    "\n",
    "# Play one full episode, appending every transition to the buffer, then\n",
    "# evict the oldest entries once the buffer exceeds its capacity.\n",
    "def update_data():\n",
    "    # Reset the environment.\n",
    "    state = env.reset()\n",
    "\n",
    "    # Play until the episode ends.\n",
    "    over = False\n",
    "    while not over:\n",
    "        # Pick an action for the current state.\n",
    "        action = get_action(state)\n",
    "\n",
    "        # Execute the action; env.step expects a 1-element action vector.\n",
    "        next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "        # Store the transition.\n",
    "        datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "        # Advance to the next state.\n",
    "        state = next_state\n",
    "\n",
    "    # Capacity limit: drop the oldest transitions beyond 100000.\n",
    "    while len(datas) > 100000:\n",
    "        datas.pop(0)\n",
    "\n",
    "\n",
    "update_data()\n",
    "\n",
    "len(datas), datas[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\cgq10\\AppData\\Local\\Temp\\ipykernel_11612\\1710091499.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:248.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.1539,  0.9881,  1.3134],\n",
       "         [-0.4580,  0.8889,  2.1198],\n",
       "         [-0.2403,  0.9707,  0.4943],\n",
       "         [-0.9818,  0.1897,  4.4617],\n",
       "         [-0.3090, -0.9511, -3.6256]]),\n",
       " tensor([[-1.1074],\n",
       "         [-1.1113],\n",
       "         [-1.4115],\n",
       "         [ 0.9659],\n",
       "         [-0.4208]]),\n",
       " tensor([[ -3.1504],\n",
       "         [ -4.6390],\n",
       "         [ -3.3152],\n",
       "         [-10.6982],\n",
       "         [ -4.8677]]),\n",
       " tensor([[-0.2464,  0.9692,  1.8883],\n",
       "         [-0.5702,  0.8215,  2.6199],\n",
       "         [-0.2891,  0.9573,  1.0106],\n",
       "         [-0.9989, -0.0465,  4.7489],\n",
       "         [-0.5092, -0.8606, -4.4020]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sample a random minibatch of 64 transitions from the replay buffer.\n",
    "def get_sample():\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    # Stack the per-sample ndarrays into one ndarray before converting:\n",
    "    # building a tensor directly from a list of numpy arrays is extremely\n",
    "    # slow (this cell previously emitted a UserWarning about exactly that).\n",
    "    #[b, 3]\n",
    "    state = torch.FloatTensor(np.array([i[0] for i in samples])).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.FloatTensor(np.array([i[3] for i in samples])).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state[:5], action[:5], reward[:5], next_state[:5], over[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1350.4743205137393"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Run one full episode and return the total reward (higher is better).\n",
    "\n",
    "    When `play` is True, render roughly every 5th frame as an animation.\n",
    "    \"\"\"\n",
    "    # Reset the environment.\n",
    "    state = env.reset()\n",
    "\n",
    "    # Accumulated reward over the episode.\n",
    "    reward_sum = 0\n",
    "\n",
    "    # Play until the episode ends.\n",
    "    over = False\n",
    "    while not over:\n",
    "        # Pick an action for the current state.\n",
    "        action = get_action(state)\n",
    "\n",
    "        # Execute the action.\n",
    "        state, reward, over, _ = env.step([action])\n",
    "        reward_sum += reward\n",
    "\n",
    "        # Render the animation, skipping ~80% of frames to keep it fast.\n",
    "        if play and random.random() < 0.2:\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_update(model, model_next):\n",
    "    # Polyak averaging: move each target-network parameter a small step\n",
    "    # (tau = 0.005) toward the corresponding online-network parameter.\n",
    "    for param, param_next in zip(model.parameters(), model_next.parameters()):\n",
    "        value = param_next.data * 0.995 + param.data * 0.005\n",
    "        param_next.data.copy_(value)\n",
    "\n",
    "\n",
    "soft_update(torch.nn.Linear(4, 64), torch.nn.Linear(4, 64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(-4.6052, requires_grad=True)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math \n",
    "# The temperature alpha is itself a learnable parameter; we store log(alpha)\n",
    "# so that alpha = exp(log_alpha) stays positive throughout training.\n",
    "alpha = torch.tensor(math.log(0.01))\n",
    "alpha.requires_grad=True\n",
    "alpha"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![SAC.png](imgs/SAC%E7%AE%97%E6%B3%95.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    # SAC TD target for the Q-networks, including the entropy bonus.\n",
    "    # Sample the next action and its entropy from the current policy.\n",
    "    # (Original comments said [b, 4]; Pendulum states are [b, 3].)\n",
    "    #[b, 3] -> [b, 1],[b, 1]\n",
    "    action, entropy = model_action(next_state)\n",
    "\n",
    "    # Evaluate next_state with the target Q-networks.\n",
    "    #[b, 3],[b, 1] -> [b, 1]\n",
    "    target1 = model_value_next1(next_state, action)\n",
    "    target2 = model_value_next2(next_state, action)\n",
    "\n",
    "    # Take the smaller of the two values (clipped double-Q) for stability.\n",
    "    #[b, 1]\n",
    "    target = torch.min(target1, target2)\n",
    "\n",
    "    # exp undoes the log storage of alpha; add the entropy bonus to the\n",
    "    # target, weighted by the temperature alpha.\n",
    "    #[b, 1] + [b, 1] -> [b, 1]\n",
    "    target += alpha.exp() * entropy\n",
    "\n",
    "    # Discount, zero out terminal states, and add the immediate reward.\n",
    "    #[b, 1]\n",
    "    target *= 0.99\n",
    "    target *= (1 - over)\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0366, grad_fn=<MeanBackward0>),\n",
       " tensor([[0.2545],\n",
       "         [0.2884],\n",
       "         [1.2884],\n",
       "         [0.2423],\n",
       "         [0.1255],\n",
       "         [1.5537],\n",
       "         [1.2793],\n",
       "         [0.3301],\n",
       "         [0.7353],\n",
       "         [0.2990],\n",
       "         [0.6835],\n",
       "         [0.4670],\n",
       "         [0.5638],\n",
       "         [0.1480],\n",
       "         [0.0869],\n",
       "         [0.0804],\n",
       "         [0.4897],\n",
       "         [0.2571],\n",
       "         [0.0275],\n",
       "         [0.5580],\n",
       "         [1.6905],\n",
       "         [0.5618],\n",
       "         [0.1329],\n",
       "         [0.1800],\n",
       "         [0.2614],\n",
       "         [0.4480],\n",
       "         [0.4295],\n",
       "         [0.4145],\n",
       "         [0.3923],\n",
       "         [0.5249],\n",
       "         [0.6440],\n",
       "         [0.5439],\n",
       "         [0.3217],\n",
       "         [0.3322],\n",
       "         [0.7698],\n",
       "         [0.5542],\n",
       "         [0.4917],\n",
       "         [0.4608],\n",
       "         [0.2854],\n",
       "         [0.7670],\n",
       "         [0.2398],\n",
       "         [0.3867],\n",
       "         [1.3694],\n",
       "         [0.2684],\n",
       "         [0.4941],\n",
       "         [0.4143],\n",
       "         [1.6786],\n",
       "         [0.4723],\n",
       "         [1.2965],\n",
       "         [1.2664],\n",
       "         [0.0564],\n",
       "         [0.3040],\n",
       "         [0.5423],\n",
       "         [0.5271],\n",
       "         [0.8222],\n",
       "         [0.3655],\n",
       "         [0.3307],\n",
       "         [0.2791],\n",
       "         [0.2562],\n",
       "         [0.3151],\n",
       "         [0.1190],\n",
       "         [0.5747],\n",
       "         [0.2836],\n",
       "         [0.7125]], grad_fn=<NegBackward0>))"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss_action(state):\n",
    "    # Policy loss: maximize Q(s, pi(s)) + alpha * entropy, negated for SGD.\n",
    "    # Sample actions and their entropies from the current policy.\n",
    "    #[b, 3] -> [b, 1],[b, 1]\n",
    "    action, entropy = model_action(state)\n",
    "\n",
    "    # Evaluate the sampled actions with both online Q-networks.\n",
    "    #[b, 3],[b, 1] -> [b, 1]\n",
    "    value1 = model_value1(state, action)\n",
    "    value2 = model_value2(state, action)\n",
    "\n",
    "    # Take the smaller value (clipped double-Q) for stability.\n",
    "    #[b, 1]\n",
    "    value = torch.min(value1, value2)\n",
    "\n",
    "    # The entropy term should be maximized, so it enters the loss negated,\n",
    "    # weighted by the temperature alpha (stored as log, hence .exp()).\n",
    "    #[1] * [b, 1] -> [b, 1]\n",
    "    loss_action = -alpha.exp() * entropy\n",
    "\n",
    "    # Subtract the value: larger Q-values make the loss smaller.\n",
    "    loss_action -= value\n",
    "\n",
    "    return loss_action.mean(), entropy\n",
    "\n",
    "\n",
    "get_loss_action(state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 0.009400197304785252 -1767.3688590655597\n",
      "10 2400 0.005681571085005999 -1419.3170014862415\n",
      "20 4400 0.0033800634555518627 -1592.504787148847\n",
      "30 6400 0.0019578449428081512 -1212.7534135907624\n",
      "40 8400 0.0012702337699010968 -441.15286511540654\n",
      "50 10400 0.0009943126933649182 -1036.265425090319\n",
      "60 12400 0.0007969569996930659 -772.8475545799033\n",
      "70 14400 0.000589747098274529 -756.5842890284275\n",
      "80 16400 0.00041159402462653816 -819.7396557828536\n",
      "90 18400 0.00029611389618366957 -562.9475449088609\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    # SAC training loop: alternate between collecting one episode of data\n",
    "    # and 200 gradient updates on minibatches from the replay buffer.\n",
    "    optimizer_action = torch.optim.Adam(model_action.parameters(), lr=3e-4)\n",
    "    optimizer_value1 = torch.optim.Adam(model_value1.parameters(), lr=3e-3)\n",
    "    optimizer_value2 = torch.optim.Adam(model_value2.parameters(), lr=3e-3)\n",
    "\n",
    "    # alpha is a trainable parameter too, so it gets its own optimizer.\n",
    "    optimizer_alpha = torch.optim.Adam([alpha], lr=3e-4)\n",
    "\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    # Train for N epochs.\n",
    "    for epoch in range(100):\n",
    "        # Collect one episode of fresh transitions.\n",
    "        update_data()\n",
    "\n",
    "        # Perform N gradient updates per data-collection round.\n",
    "        for i in range(200):\n",
    "            # Sample a minibatch.\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            # Shift/scale the reward toward [-1, 1] to ease training.\n",
    "            reward = (reward + 8) / 8\n",
    "\n",
    "            # Compute the TD target (already includes the entropy bonus);\n",
    "            # detach so no gradients flow back through the target networks.\n",
    "            #[b, 1]\n",
    "            target = get_target(reward, next_state, over)\n",
    "            target = target.detach()\n",
    "\n",
    "            # Current estimates from both online Q-networks.\n",
    "            value1 = model_value1(state, action)\n",
    "            value2 = model_value2(state, action)\n",
    "\n",
    "            # Both Q-networks regress toward the same target.\n",
    "            loss_value1 = loss_fn(value1, target)\n",
    "            loss_value2 = loss_fn(value2, target)\n",
    "\n",
    "            # Update the Q-networks.\n",
    "            optimizer_value1.zero_grad()\n",
    "            loss_value1.backward()\n",
    "            optimizer_value1.step()\n",
    "\n",
    "            optimizer_value2.zero_grad()\n",
    "            loss_value2.backward()\n",
    "            optimizer_value2.step()\n",
    "\n",
    "            # Update the policy using the Q-networks as critics.\n",
    "            loss_action, entropy = get_loss_action(state)\n",
    "            optimizer_action.zero_grad()\n",
    "            loss_action.backward()\n",
    "            optimizer_action.step()\n",
    "\n",
    "            # Temperature loss: (entropy + 1) drives entropy toward -1;\n",
    "            # entropy is detached so only alpha receives gradients.\n",
    "            #[b, 1] -> [1]\n",
    "            loss_alpha = (entropy + 1).detach() * alpha.exp()\n",
    "            loss_alpha = loss_alpha.mean()\n",
    "\n",
    "            # Update alpha.\n",
    "            optimizer_alpha.zero_grad()\n",
    "            loss_alpha.backward()\n",
    "            optimizer_alpha.step()\n",
    "\n",
    "            # Soft-update the target Q-networks toward the online ones.\n",
    "            soft_update(model_value1, model_value_next1)\n",
    "            soft_update(model_value2, model_value_next2)\n",
    "\n",
    "        # Periodically evaluate: average total reward over 10 episodes.\n",
    "        if epoch % 10 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(10)]) / 10\n",
    "            print(epoch, len(datas), alpha.exp().item(), test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAisUlEQVR4nO3df3DU9YH/8ddns9lNQrIbEsyuORLhWxkhw4+2gLC11/otKdGmtlZuxlqqjMfp6AUPpMNUWsWpczNx7H3P6p3Fu+u0+EctHTpFKwe2uaChPWPASGoAzWG/aFJxE340uwmQzY99f/+g7NcVtAnZ7L43eT5mdsZ8Pu/95L0fTZ5+9vPJZx1jjBEAABZyZXoCAAB8FCIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALBWxiL11FNPadasWcrLy9OyZcu0f//+TE0FAGCpjETq5z//uTZu3KiHH35Yr7/+uhYtWqSamhr19PRkYjoAAEs5mbjB7LJly7R06VL967/+qyQpHo+roqJC9913nx544IF0TwcAYCl3ur/h4OCgWltbtXnz5sQyl8ul6upqNTc3X/I5sVhMsVgs8XU8Htfp06dVWloqx3EmfM4AgNQyxqivr0/l5eVyuT76Tb20R+rkyZMaGRlRIBBIWh4IBPTWW29d8jn19fX63ve+l47pAQDSqKurSzNnzvzI9WmP1OXYvHmzNm7cmPg6EomosrJSXV1d8vl8GZwZAOByRKNRVVRUqKio6GPHpT1SM2bMUE5Ojrq7u5OWd3d3KxgMXvI5Xq9XXq/3ouU+n49IAUAW+0unbNJ+dZ/H49HixYvV2NiYWBaPx9XY2KhQKJTu6QAALJaRt/s2btyoNWvWaMmSJbr22mv1gx/8QGfOnNGdd96ZiekAACyVkUjdeuutOnHihLZs2aJwOKxPfvKTevHFFy+6mAIAMLVl5O+kxisajcrv9ysSiXBOCgCy0Gh/j3PvPgCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWGnOk9u3bp5tuuknl5eVyHEfPPfdc0npjjLZs2aIrr7xS+fn5qq6u1tGjR5PGnD59WqtXr5bP51NxcbHWrl2r/v7+cb0QAMDkM+ZInTlzRosWLdJTTz11yfWPPfaYnnzyST399NNqaWnRtGnTVFNTo4GBgcSY1atX6/Dhw2poaNCuXbu0b98+3X333Zf/KgAAk5MZB0lm586dia/j8bgJBoPm+9//fmJZb2+v8Xq95mc/+5kxxpgjR44YSebAgQOJMXv27DGO45j33ntvVN83EokYSSYSiYxn+gCADBnt7/GUnpM6duyYwuGwqqurE8v8fr+WLVum5uZmSVJzc7OKi4u1ZMmSxJjq6mq5XC61tLRccruxWEzRaDTpAQCY/FIaqXA4LEkKBAJJywOBQGJdOBxWWVlZ0nq3262SkpLEmA+rr6+X3+9PPCoqKlI5bQCApbLi6r7NmzcrEokkHl1dXZmeEgAgDVIaqWAwKEnq7u5OWt7d3Z1YFwwG1dPTk7R+eHhYp0+fToz5MK/XK5/Pl/QAAEx+KY3U7NmzFQwG1djYmFgWjUbV0tKi
UCgkSQqFQurt7VVra2tizN69exWPx7Vs2bJUTgcAkOXcY31Cf3+/3n777cTXx44dU1tbm0pKSlRZWakNGzboH//xHzVnzhzNnj1bDz30kMrLy3XzzTdLkubNm6cbbrhBd911l55++mkNDQ1p3bp1+vrXv67y8vKUvTAAwCQw1ssGX3rpJSPposeaNWuMMecvQ3/ooYdMIBAwXq/XrFixwnR0dCRt49SpU+a2224zhYWFxufzmTvvvNP09fWl/NJFAICdRvt73DHGmAw28rJEo1H5/X5FIhHOTwFAFhrt7/GsuLoPADA1ESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLXGfINZAONjjFEsHNbZo0c1ePKk4rGYXHl58paVadrcufKUlmZ6ioA1iBSQJsYYDff26uR//ZdO79unoVOnNHLunDQyIsftlis/X95AQKVf+IJKrr9eOdOmyXGcTE8byCgiBaTJuXfeUde//7v6jxyRPnRfZzM8rJG+Pp3t69PZP/xB0bY2Vfzd38n7ER8ECkwVnJMC0uBcV5e6/uM/1H/48EWBuogxihw4oD9u26bYhz7lGphqiBQwwYajUYV/8YvzR1CjZYwi+/erZ9eu828JAlMUkQImkInH1XvggE6/9JIUj4/tucPD6nn+eZ05elRZ+LFvQEoQKWACxWMxndi9e1zb6N658y+/RQhMUkQKmEBmeFhn3357XNvoP3JEw319KZoRkF2IFGC5+LlzeueJJ2TG+HYhMBkQKSALmOFhxQcHMz0NIO2IFJAFzNAQkcKURKSALGCGhxWPxTI9DSDtiBSQBeJDQzIcSWEKIlJAFuBIClMVkQKywHAkosETJzI9DSDtiBQw0VJwJ/PhSESDPT0pmAyQXYgUMIFcXq9Krr8+09MAshaRAiaS48hdVJTpWQBZi0gBE8hxHOXk52d6GkDWIlLARHIc5UyblulZAFmLSAETyXHkStGRlInH+cgOTDlECphIjqOcgoKUbCoei435M6mAbEekgAnkOI4cV2p+zEbOnZMZGUnJtoBsQaSALBEfGODjOjDlECkgS4ycO8fbfZhyiBSQJeK83YcpiEgBWeLs0aMaOXs209MA0opIARPMcbvluN3j3s7gyZN88CGmHCIFTLC8igoVzJmT6WkAWYlIARPM5XbL5fFkehpAViJSwARzPB65vN5MTwPISkQKmGCu3FyOpIDLRKSACea43XJyc1OzMe7fhymGSAETzEnhkdQwl6BjiiFSwARzHCclHyEvSSP9/SnZDpAtiBSQRfhjXkw1RArIIhxJYaohUkAWGTlzJtNTANKKSAFZ5E+//W2mpwCkFZEC0iC3uFhOTs64tzPU2zv+yQBZhEgBaVC0YIFypk3L9DSArEOkgDRw5edLKfoYeWAq4acGSIOc/Hw5RAoYM35qgDRw5eVxJAVcBn5qgDTIyc9PyYUTkiTu3YcphEgBaeDKyzt/e6RxMvE4d53AlEKkgDRwXK7U3L/PGMXPnRv/doAsQaSAbGIMR1KYUogUkEWMMRrhSApTCJECsgnnpDDFECkgi5h4XEOnTmV6GkDaECkgTfJnzRr3NszgoKJtbePeDpAtiBSQJkULF2Z6CkDWGVOk6uvrtXTpUhUVFamsrEw333yzOjo6ksYMDAyorq5OpaWlKiws1KpVq9Td3Z00prOzU7W1tSooKFBZWZk2bdqk4eHh8b8awGI5hYWZngKQdcYUqaamJtXV1enVV19VQ0ODhoaGtHLlSp35wAex3X///XrhhRe0Y8cONTU16fjx47rlllsS60dGRlRbW6vBwUG98soreuaZZ7Rt2zZt2bIlda8KsBB3QQfGzjHm8u+xcuLECZWVlampqUmf+9znFIlEdMUVV+jZZ5/V3/zN30iS3nrrLc2bN0/Nzc1avny59uzZoy9/+cs6fvy4AoGAJOnpp5/Wt7/9bZ04cUIej+cvft9oNCq/369IJCKf
z3e50wfSqv/NN9Xx7W+PezvT//qv9b82bUrBjIDMGe3v8XGdk4pEIpKkkpISSVJra6uGhoZUXV2dGDN37lxVVlaqublZktTc3KwFCxYkAiVJNTU1ikajOnz48CW/TywWUzQaTXoA2SZlb/cZo3H8vyWQVS47UvF4XBs2bNB1112n+fPnS5LC4bA8Ho+Ki4uTxgYCAYXD4cSYDwbqwvoL6y6lvr5efr8/8aioqLjcaQMZk5Ofn5LtmOFhmaGhlGwLsN1lR6qurk6HDh3S9u3bUzmfS9q8ebMikUji0dXVNeHfE0i5VNy7T1J8aEjxwcGUbAuwnftynrRu3Trt2rVL+/bt08yZMxPLg8GgBgcH1dvbm3Q01d3drWAwmBizf//+pO1duPrvwpgP83q98nq9lzNVYNIxg4NEClPGmI6kjDFat26ddu7cqb1792r27NlJ6xcvXqzc3Fw1NjYmlnV0dKizs1OhUEiSFAqF1N7erp6ensSYhoYG+Xw+VVVVjee1AFNCfGhIhkhhihjTkVRdXZ2effZZPf/88yoqKkqcQ/L7/crPz5ff79fatWu1ceNGlZSUyOfz6b777lMoFNLy5cslSStXrlRVVZVuv/12PfbYYwqHw3rwwQdVV1fH0RIwCmZ4WHHOSWGKGNOR1NatWxWJRHT99dfryiuvTDx+/vOfJ8Y8/vjj+vKXv6xVq1bpc5/7nILBoH75y18m1ufk5GjXrl3KyclRKBTSN7/5Td1xxx165JFHUveqgElsKBLR0J/+lOlpAGkxrr+TyhT+TgrZaLivT0cfflhn33573Nu6at06zVi5MgWzAjIjLX8nBWD0nNxc5VVWZnoaQFYhUkCaOC6XcgoKMj0NIKsQKSBdHCdlf9ALTBVECkgTx+WSi0gBY0KkgHRxnJS93We4fx+mCCIFpIvjyJWivwWMx2ISkcIUQKSANHEcR06K7t83cvasTDyekm0BNiNSQBaKnzsnESlMAUQKyEIjAwOck8KUQKSALDRy9ixHUpgSiBSQhc794Q/nL54AJjkiBaSR96/+St7y8nFvZ+CPf+QzpTAlECkgjdyFhXIXFWV6GkDWIFJAGjm5uXJyczM9DSBrECkgjVwej1xEChg1IgWkkSs3V47Hk+lpAFmDSAFp5KTySIr792EKIFJAGjlud8rOSY2cO5eS7QA2I1JAGqXq3n2SNHLmTMq2BdjKnekJAFPdhbfs+oaHtf/ECb3b3y9J+uYnPiFvTs5HPo9IYSogUkAGGWPUNzSk5hMn9FZvr5bMmKHaigp5XS7luj7+jY7hP8cMmMyIFJBBJwYG9JvjxxXIy9Pd11yjfPfofyQ5ksJUwDkpIN3+fF7qT7GYXnzvPS2YPl0rysvHFChJirS0TMTsAKsQKSDNipcuVdzr1Ss9Pbra59OC6dPluowLKmLvvz8BswPswtt9QJrlFBXp1NCQTsdiuj4YvChQ7505o4OnT6tvaEhX5OUpdMUVmsZdKjBFESkgzXLy89URjWp2UVFSfIwxOtbfr4cPHtQ7/f0aGBmRLzdX86dP1z8tXZp0IYUxRnH+kBdTAG/3AWmWU1Cg9t5efbKkJGn5/+3v113//d96MxLRuZERGUmRoSH9d0+P1re06NTAQGLsYDyune++m+aZA+lHpIA0cxUUaMAYFX7oLbwfHD6syNDQJZ+z/+RJNRw/nvj6jdOn9fw776inp2dC5wpkGpEC0iwnP1+eMV7J90Ej8bgO9/Yq0ten5ubmFM4MsA/npIA0c3m9inu9l/38YWP0xfJynS0q0vLly1M4M8A+HEkBGTDtqqsUGxlJWlZbUaHcj7gUfVZhoRb++RyWNydHZ4eHdZXPp0AgMOFzBTKJSAEZ8KnPflbvB4NJy2rKy/Xwpz6lvJycxA9mjuOo1OvV/1m6VFXFxZLOX9l3JBLRX3/jG+mdNJABvN0HZMD/XrlSz7z8smZ6vXLHYpLO3yG9prxcMwsKtOuPf9SpgQHNKizUrbNn
q/QDbw/+aXBQA0VFWnDjjSm9qzpgIyIFpJnjOAoGg7rys59V+NgxzTx2TPrzW3+O42j+9OmaP336JZ87Yoza+/v1ya9/Xf7S0nROG8gI3u4DMqCgoEAramr0usejwXnzpI/5SI4L4sbo2Nmz+tOiRbrhttuUy8fQYwogUkAGOI6jOXPmaNU3vqEX43G994lPSB9x9GSM0dl4XK/19+v3c+Zo/SOPaDpHUZgiiBSQIY7jaMGCBbr1jjt07MortSsnR+Grr1bOzJly5eVJjqOY16tDbreeGxpSzk036R8eflilpaWci8KUwTkpIIMcx9G8efN01VVX6d1339Vvdu/W821tGo7HFXe5VJCXpyXXXad/uPFGBYNBeTweAoUphUgBGeY4jqZNm6Z58+Zp3rx5HzsOmGqIFGAJIgRcjHNSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKw1pkht3bpVCxculM/nk8/nUygU0p49exLrBwYGVFdXp9LSUhUWFmrVqlXq7u5O2kZnZ6dqa2tVUFCgsrIybdq0ScPDw6l5NQCASWVMkZo5c6YeffRRtba26rXXXtMXvvAFffWrX9Xhw4clSffff79eeOEF7dixQ01NTTp+/LhuueWWxPNHRkZUW1urwcFBvfLKK3rmmWe0bds2bdmyJbWvCgAwOZhxmj59uvnRj35kent7TW5urtmxY0di3ZtvvmkkmebmZmOMMbt37zYul8uEw+HEmK1btxqfz2disdiov2ckEjGSTCQSGe/0AQAZMNrf45d9TmpkZETbt2/XmTNnFAqF1NraqqGhIVVXVyfGzJ07V5WVlWpubpYkNTc3a8GCBQoEAokxNTU1ikajiaOxS4nFYopGo0kPAMDkN+ZItbe3q7CwUF6vV/fcc4927typqqoqhcNheTweFRcXJ40PBAIKh8OSpHA4nBSoC+svrPso9fX18vv9iUdFRcVYpw0AyEJjjtQ111yjtrY2tbS06N5779WaNWt05MiRiZhbwubNmxWJRBKPrq6uCf1+AAA7uMf6BI/Ho6uvvlqStHjxYh04cEBPPPGEbr31Vg0ODqq3tzfpaKq7u1vBYFCSFAwGtX///qTtXbj678KYS/F6vfJ6vWOdKgAgy43776Ti8bhisZgWL16s3NxcNTY2JtZ1dHSos7NToVBIkhQKhdTe3q6enp7EmIaGBvl8PlVVVY13KgCASWZMR1KbN2/WjTfeqMrKSvX19enZZ5/Vyy+/rF//+tfy+/1au3atNm7cqJKSEvl8Pt13330KhUJavny5JGnlypWqqqrS7bffrscee0zhcFgPPvig6urqOFICAFxkTJHq6enRHXfcoffff19+v18LFy7Ur3/9a33xi1+UJD3++ONyuVxatWqVYrGYampq9MMf/jDx/JycHO3atUv33nuvQqGQpk2bpjVr1uiRRx5J7asCAEwKjjHGZHoSYxWNRuX3+xWJROTz+TI9HQDAGI329zj37gMAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1
iBQAwFpECgBgrXFF6tFHH5XjONqwYUNi2cDAgOrq6lRaWqrCwkKtWrVK3d3dSc/r7OxUbW2tCgoKVFZWpk2bNml4eHg8UwEATEKXHakDBw7o3/7t37Rw4cKk5ffff79eeOEF7dixQ01NTTp+/LhuueWWxPqRkRHV1tZqcHBQr7zyip555hlt27ZNW7ZsufxXAQCYnMxl6OvrM3PmzDENDQ3m85//vFm/fr0xxpje3l6Tm5trduzYkRj75ptvGkmmubnZGGPM7t27jcvlMuFwODFm69atxufzmVgsNqrvH4lEjCQTiUQuZ/oAgAwb7e/xyzqSqqurU21traqrq5OWt7a2amhoKGn53LlzVVlZqebmZklSc3OzFixYoEAgkBhTU1OjaDSqw4cPX/L7xWIxRaPRpAcAYPJzj/UJ27dv1+uvv64DBw5ctC4cDsvj8ai4uDhpeSAQUDgcToz5YKAurL+w7lLq6+v1ve99b6xTBQBkuTEdSXV1dWn9+vX66U9/qry8vIma00U2b96sSCSSeHR1daXtewMAMmdMkWptbVVPT48+/elPy+12y+12q6mpSU8++aTcbrcCgYAGBwfV29ub9Lzu7m4Fg0FJUjAYvOhqvwtfXxjzYV6vVz6fL+kBAJj8xhSpFStWqL29XW1tbYnHkiVLtHr16sQ/5+bmqrGxMfGcjo4OdXZ2KhQKSZJCoZDa29vV09OTGNPQ0CCfz6eqqqoUvSwAwGQwpnNSRUVFmj9/ftKyadOmqbS0NLF87dq12rhxo0pKSuTz+XTfffcpFApp+fLlkqSVK1eqqqpKt99+ux577DGFw2E9+OCDqqurk9frTdHLAgBMBmO+cOIvefzxx+VyubRq1SrFYjHV1NTohz/8YWJ9Tk6Odu3apXvvvVehUEjTpk3TmjVr9Mgjj6R6KgCALOcYY0ymJzFW0WhUfr9fkUiE81MAkIVG+3uce/cBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKzlzvQELocxRpIUjUYzPBMAwOW48Pv7wu/zj5KVkTp16pQkqaKiIsMzAQCMR19fn/x+/0euz8pIlZSUSJI6Ozs/9sVNddFoVBUVFerq6pLP58v0dKzFfhod9tPosJ9Gxxijvr4+lZeXf+y4rIyUy3X+VJrf7+c/glHw+Xzsp1FgP40O+2l02E9/2WgOMrhwAgBgLSIFALBWVkbK6/Xq4YcfltfrzfRUrMZ+Gh320+iwn0aH/ZRajvlL1/8BAJAhWXkkBQCYGogUAMBaRAoAYC0iBQCwVlZG6qmnntKsWbOUl5enZcuWaf/+/ZmeUlrt27dPN910k8rLy+U4jp577rmk9cYYbdmyRVdeeaXy8/NVXV2to0ePJo05ffq0Vq9eLZ/Pp+LiYq1du1b9/f1pfBUTq76+XkuXLlVRUZHKysp08803q6OjI2nMwMCA6urqVFpaqsLCQq1atUrd3d1JYzo7O1VbW6uCggKVlZVp06ZNGh4eTudLmVBbt27VwoULE394GgqFtGfPnsR69tGlPfroo3IcRxs2bEgsY19NEJNltm/fbjwej/nxj39sDh8+bO666y5TXFxsuru7Mz21tNm9e7f57ne/a375y18a
SWbnzp1J6x999FHj9/vNc889Z37/+9+br3zlK2b27Nnm3LlziTE33HCDWbRokXn11VfNb3/7W3P11Veb2267Lc2vZOLU1NSYn/zkJ+bQoUOmra3NfOlLXzKVlZWmv78/Meaee+4xFRUVprGx0bz22mtm+fLl5jOf+Uxi/fDwsJk/f76prq42Bw8eNLt37zYzZswwmzdvzsRLmhC/+tWvzH/+53+a//mf/zEdHR3mO9/5jsnNzTWHDh0yxrCPLmX//v1m1qxZZuHChWb9+vWJ5eyriZF1kbr22mtNXV1d4uuRkRFTXl5u6uvrMzirzPlwpOLxuAkGg+b73/9+Yllvb6/xer3mZz/7mTHGmCNHjhhJ5sCBA4kxe/bsMY7jmPfeey9tc0+nnp4eI8k0NTUZY87vk9zcXLNjx47EmDfffNNIMs3NzcaY8/8z4HK5TDgcTozZunWr8fl8JhaLpfcFpNH06dPNj370I/bRJfT19Zk5c+aYhoYG8/nPfz4RKfbVxMmqt/sGBwfV2tqq6urqxDKXy6Xq6mo1NzdncGb2OHbsmMLhcNI+8vv9WrZsWWIfNTc3q7i4WEuWLEmMqa6ulsvlUktLS9rnnA6RSETS/785cWtrq4aGhpL209y5c1VZWZm0nxYsWKBAIJAYU1NTo2g0qsOHD6dx9ukxMjKi7du368yZMwqFQuyjS6irq1NtbW3SPpH472kiZdUNZk+ePKmRkZGkf8mSFAgE9NZbb2VoVnYJh8OSdMl9dGFdOBxWWVlZ0nq3262SkpLEmMkkHo9rw4YNuu666zR//nxJ5/eBx+NRcXFx0tgP76dL7ccL6yaL9vZ2hUIhDQwMqLCwUDt37lRVVZXa2trYRx+wfft2vf766zpw4MBF6/jvaeJkVaSAy1FXV6dDhw7pd7/7XaanYqVrrrlGbW1tikQi+sUvfqE1a9aoqakp09OySldXl9avX6+Ghgbl5eVlejpTSla93Tdjxgzl5ORcdMVMd3e3gsFghmZllwv74eP2UTAYVE9PT9L64eFhnT59etLtx3Xr1mnXrl166aWXNHPmzMTyYDCowcFB9fb2Jo3/8H661H68sG6y8Hg8uvrqq7V48WLV19dr0aJFeuKJJ9hHH9Da2qqenh59+tOfltvtltvtVlNTk5588km53W4FAgH21QTJqkh5PB4tXrxYjY2NiWXxeFyNjY0KhUIZnJk9Zs+erWAwmLSPotGoWlpaEvsoFAqpt7dXra2tiTF79+5VPB7XsmXL0j7niWCM0bp167Rz507t3btXs2fPTlq/ePFi5ebmJu2njo4OdXZ2Ju2n9vb2pKA3NDTI5/OpqqoqPS8kA+LxuGKxGPvoA1asWKH29na1tbUlHkuWLNHq1asT/8y+miCZvnJjrLZv3268Xq/Ztm2bOXLkiLn77rtNcXFx0hUzk11fX585ePCgOXjwoJFk/vmf/9kcPHjQvPvuu8aY85egFxcXm+eff9688cYb5qtf/eolL0H/1Kc+ZVpaWszvfvc7M2fOnEl1Cfq9995r/H6/efnll83777+feJw9ezYx5p577jGVlZVm79695rXXXjOhUMiEQqHE+guXDK9cudK0tbWZF1980VxxxRWT6pLhBx54wDQ1NZljx46ZN954wzzwwAPGcRzzm9/8xhjDPvo4H7y6zxj21UTJukgZY8y//Mu/mMrKSuPxeMy1115rXn311UxPKa1eeuklI+mix5o1a4wx5y9Df+ihh0wgEDBer9esWLHCdHR0JG3j1KlT5rbbbjOFhYXG5/OZO++80/T19WXg1UyMS+0fSeYnP/lJYsy5c+fM3//935vp06ebgoIC87Wvfc28//77Sdt55513zI033mjy8/PNjBkzzLe+9S0zNDSU5lczcf72b//WXHXVVcbj8ZgrrrjCrFixIhEoY9hHH+fDkWJfTQw+qgMAYK2sOicFAJhaiBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAA
axEpAIC1iBQAwFpECgBgLSIFALDW/wPZJ7W7lXygEQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-132.50931630808057"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#play one evaluation episode with rendering enabled;\n",
    "#the figure above is the rendered frame, the number below is the episode's total reward\n",
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Gym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
